from difflib import SequenceMatcher

import numpy as np

from .new_alignment import ScoreParam, SeqGraphAlignment

# Punctuation tokens looked for (as substrings) inside words when scoring a
# word pair.  NOTE: "..." is redundant with "." under a substring test but is
# kept so the list contents stay unchanged.
PUNCTUATION_MARKS = [
    ".",
    "!",
    "?",
    ",",
    ":",
    ";",
    "...",
    "(",
    ")",
]

class TextSeqGraphAlignment(SeqGraphAlignment):
    """Word-level alignment of a text against a sequence graph.

    The sequence is a list of words (whitespace-split when a string is
    given).  Graph nodes are compared through their ``text`` attribute.
    Word pairs are scored with a fuzzy, ``difflib``-based similarity
    instead of exact equality.  The fast DP uses affine gap penalties
    (``gap_open`` + ``gap_extend``).
    """

    # Similarity ratio above which two words score as a full match.
    match_similarity_threshold = 0.8
    # Similarity ratio above which two words get a proportionally scaled match
    # score; at or below it they score as a mismatch.
    partial_similarity_threshold = 0.5
    # Multiplier applied to a (non-mismatch) score when either word contains
    # one of PUNCTUATION_MARKS.
    punctuation_weight = 1.5

    def __init__(
        self,
        text,
        graph,
        fastMethod=True,
        globalAlign=True,
        matchscore=1,
        mismatchscore=-3,
        gap_open=-2,
        gap_extend=-1,
        position_weight=0.1,
        *args,
        **kwargs,
    ):
        """Build the word sequence and score parameters, then defer to the parent.

        Args:
            text: Either a string, which is split on whitespace, or an
                iterable of words, which is materialised into a list.
            graph: Graph to align against; forwarded to the parent.
            fastMethod: Forwarded to the parent.
            globalAlign: Forwarded to the parent.
            matchscore: ``ScoreParam.match``.
            mismatchscore: ``ScoreParam.mismatch``.
            gap_open: ``ScoreParam.gap_open``.
            gap_extend: ``ScoreParam.gap_extend``.
            position_weight: Stored on the instance.
                NOTE(review): no visible scoring code reads it.
            *args, **kwargs: Forwarded unchanged to the parent ``__init__``.
        """
        score_params = ScoreParam(
            match=matchscore, mismatch=mismatchscore, gap_open=gap_open, gap_extend=gap_extend
        )

        if isinstance(text, str):
            self.original_text = text
            self.sequence = text.split()
        else:
            # Materialise once: a generator would otherwise be consumed by the
            # join below, and the fast aligner insists on a list.
            self.sequence = list(text)
            self.original_text = " ".join(self.sequence)
        self.position_weight = position_weight

        super().__init__(
            self.sequence,
            graph,
            fastMethod,
            globalAlign=globalAlign,
            score_params=score_params,
            *args,
            **kwargs,
        )

    def string_similarity(self, s1, s2):
        """Return ``difflib.SequenceMatcher`` ratio of *s1* and *s2* (0.0-1.0)."""
        return SequenceMatcher(None, s1, s2).ratio()

    def matchscore(self, word1: str, word2: str) -> float:
        """Score a word pair by fuzzy string similarity.

        The similarity ratio of the two words is bucketed as follows:

        * above ``match_similarity_threshold``: ``self.score.match``;
        * above ``partial_similarity_threshold``:
          ``self.score.match * ratio``;
        * otherwise: ``self.score.mismatch``, returned as-is.

        Non-mismatch scores are multiplied by ``punctuation_weight`` when
        either word contains any of ``PUNCTUATION_MARKS`` as a substring.
        The words' positions are not taken into account.
        """
        similarity = self.string_similarity(word1, word2)

        if similarity > self.match_similarity_threshold:
            score = self.score.match
        elif similarity > self.partial_similarity_threshold:
            score = self.score.match * similarity
        else:
            # Mismatches are never punctuation-weighted.
            return self.score.mismatch

        if any(mark in word1 or mark in word2 for mark in PUNCTUATION_MARKS):
            score = score * self.punctuation_weight

        return score

    def alignmentStrings(self):
        """Return ``(aligned_sequence, aligned_graph)`` as space-joined words.

        A ``None`` index in ``self.stringidxs`` / ``self.nodeidxs`` becomes
        ``"-"``.  These index lists are presumably filled in by the parent
        after alignment (TODO confirm).
        """
        aligned_seq = [self.sequence[i] if i is not None else "-" for i in self.stringidxs]
        aligned_graph = [
            self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs
        ]
        return " ".join(aligned_seq), " ".join(aligned_graph)

    def alignStringToGraphFast(self):
        """Align ``self.sequence`` to the graph with affine-gap dynamic programming.

        Three score layers are filled, indexed ``[layer, row, col]``.
        Row ``i + 1`` is the i-th node yielded by ``nodeiterator`` and
        column ``j + 1`` is word ``j``:

        * ``M`` (0): node ``i`` aligned with word ``j`` (diagonal move);
        * ``X`` (1): word ``j`` aligned to a gap in the graph (same row);
        * ``Y`` (2): node ``i`` aligned to a gap in the sequence (same column).

        Entering a gap layer from a different layer costs
        ``gap_open + gap_extend``.  Staying in the same gap layer costs
        ``gap_extend``.  Each cell keeps the maximum candidate tuple
        ``(score, graph_row, seq_col, layer)``, so ties are broken by the
        remaining fields.

        Returns:
            Whatever ``self.backtrack`` returns for the filled tables.

        Raises:
            TypeError: If ``self.sequence`` is not a list.
        """
        if not isinstance(self.sequence, list):
            raise TypeError("Sequence must be a list of words")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
            self.initializeDynamicProgrammingData()
        )
        M, X, Y = 0, 1, 2
        gap_extend = self.score.gap_extend
        open_cost = self.score.gap_open + self.score.gap_extend

        ni = self.graph.nodeiterator()
        # NOTE(review): assumes enumeration order of ni() matches the row layout
        # produced by initializeDynamicProgrammingData — confirm.
        for i, node in enumerate(ni()):
            gbase = node.text
            row = i + 1
            # Predecessor rows depend only on the node; fetch them once.
            pred_rows = [pred + 1 for pred in self.prevIndices(node, nodeIDtoIndex)]

            for j, sbase in enumerate(self.sequence):
                col = j + 1
                # The fuzzy word score is the expensive term: compute it once per
                # (node, word) rather than three times per predecessor.
                match = self.matchscore(sbase, gbase)

                candidates_X = [
                    (open_cost + scores[M, row, j], row, j, M),
                    (gap_extend + scores[X, row, j], row, j, X),
                    (open_cost + scores[Y, row, j], row, j, Y),
                ]
                candidates_Y = []
                candidates_M = []
                for prow in pred_rows:
                    candidates_Y += [
                        (open_cost + scores[M, prow, col], prow, col, M),
                        (open_cost + scores[X, prow, col], prow, col, X),
                        (gap_extend + scores[Y, prow, col], prow, col, Y),
                    ]
                    candidates_M += [
                        (match + scores[M, prow, j], prow, j, M),
                        (match + scores[X, prow, j], prow, j, X),
                        (match + scores[Y, prow, j], prow, j, Y),
                    ]

                # All candidates were built before any write, so the write order
                # across layers does not affect the result.
                for layer, candidates in (
                    (M, candidates_M),
                    (X, candidates_X),
                    (Y, candidates_Y),
                ):
                    (
                        scores[layer, row, col],
                        backGrphIdx[layer, row, col],
                        backStrIdx[layer, row, col],
                        backMtxIdx[layer, row, col],
                    ) = max(candidates)

        return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)
